

import numpy as np
import sys
import itertools as it
import matplotlib.pyplot as plt

from channel import *

class eq:
	'''
	Description:
		Equalizers (EQs) are implemented in this class:
		FIR (FFE), DFE, Viterbi and forward-backward (FwdBwd) detectors.
		A bit-error-rate checker (berChecker) is also provided.
		NRZ, PAM4 and PAM8 modulations are supported.
	'''
	def __init__(self, estTap=None, mod='nrz', snr=200, stateGen=False):
		'''
		Description:
			Normalize estTap(eqSBR, equalizer single-bit-response) to make signal power is equal to 1
			It makes hiddenState, transition matrix, expected value for vitebi and fwdBwd algorithm.
			It will take long time for high-modulation and large SBR. (numState = modNum^len(SBR))
		Params:
			estTap(float, list): Equalizer SBR, =eqSBR.
			mod(str):	modulation
			snr(float): Noise used for forward-backward algorithm
			stateGen(bool): If False, it skips generating hiddeState, transition matrix and expected value for viterbi and fwdBwd. 
				False only when you do not run fwdBwd and viterbi.
		'''
		self.mod_ = mod
		#assert self.mod_ in MOD

		#@@ Check input estmiated SBR and normalize it.
		if estTap is None:
			self.estTap_ = np.array([1])
		else:
			self.estTap_ = np.array(estTap, dtype=float)
			self.estTap_ = self.estTap_/np.linalg.norm(self.estTap_)
		if 0:
			print(self.estTap_)

		if stateGen:
			print("Generating states in eq class...", end=" ", flush=True)
			#@@ Decide #-of-state of viterbi according to modulation
			if self.mod_ == 'nrz':
				self.numState_	= 2**len(self.estTap_)
			elif self.mod_=='pam4':
				self.numState_  = 4**len(self.estTap_)
			elif self.mod_=='pam8':
				self.numState_  = 8**len(self.estTap_)
			else:
				print("Invalid modulation")
		
			if self.mod_ == 'nrz' :
				#@@ State transition matrix
				self.transit_ = np.kron(np.eye(self.numState_//2), np.ones(2))
				self.transit_ = np.append(self.transit_, self.transit_, axis=0)
				#@@ Hidden states generation
				self.states_ = [s for s in it.product([1,-1], repeat = len(self.estTap_))]
				#@@ Expected output according to relation between channel sbr and input sequence
				self.expected_ = []
				#@@ semi-convolution(?) with sbr and sequence
				for i in range(self.numState_):
					self.expected_.append(self.estTap_[::-1].dot(self.states_[i]))
				if 0:
					print(self.states_)
					print(self.transit_)
					print(self.expected_)
			elif self.mod_=='pam4':
				#@@ State transition matrix
				self.transit_ = np.kron(np.eye(self.numState_//4), np.ones(4))
				self.transit_ = np.append(self.transit_, self.transit_, axis=0)
				self.transit_ = np.append(self.transit_, self.transit_, axis=0)
				#@@ Hidden states
				self.states_ = [s for s in it.product([1,1./3,-1./3,-1], repeat = len(self.estTap_))]
				#@@ Expected output due to relation between channel sbr and input sequence
				self.expected_ = []
				#@@ semi-convolution(?) with sbr and sequence
				for i in range(self.numState_):
					self.expected_.append(self.estTap_[::-1].dot(self.states_[i]))
	
				if 0:
					print(self.states_)
					print(self.transit_)
					print(self.expected_)
			elif self.mod_=='pam8':
				#@@ State transition matrix
				self.transit_ = np.kron(np.eye(self.numState_//8), np.ones(8))
				self.transit_ = np.append(self.transit_, self.transit_, axis=0)
				self.transit_ = np.append(self.transit_, self.transit_, axis=0)
				self.transit_ = np.append(self.transit_, self.transit_, axis=0)
				#@@ Hidden states
				self.states_ = [s for s in it.product([1,5./7,3./7,1./7,-1./7,-3./7,-5./7,-1], repeat = len(self.estTap_))]
				#@@ Expected output due to relation between channel sbr and input sequence
				self.expected_ = []
				#@@ semi-convolution(?) with sbr and sequence
				for i in range(self.numState_):
					self.expected_.append(self.estTap_[::-1].dot(self.states_[i]))
	
				if 0:
					print(self.states_)
					print(self.transit_)
					print(self.expected_)
	
				pass
			else:
				print("Invalid modulation")
			#print(self.sigma_)
			print("Done", flush=True)

	def berChecker(self, ref, rxOut, delay=0, offsetStart=0, offsetEnd=1, log=False):
		'''
		Description:
			Check whether rxOut (sliced equalizer output) matches ref (tx symbols) and count errors.
			delay must be set properly because an equalizer can add delay to its input signal.
			offsetStart and offsetEnd can be used to ignore unstable edges of the sequences.
			Entries that do not compare equal (e.g. NaN padding from fwdBwd) count as errors.
		Params:
			ref(float, list)	: reference tx symbols. NRZ: -1,1 // PAM4: -1,-1/3,1/3,1 // PAM8: -1,...,1
			rxOut(int, list)	: sliced output indices. NRZ: 0,1 // PAM4: 0,1,2,3 // PAM8: 0,1,2 ... 7
			delay(int)			: number of leading rxOut samples dropped before comparing.
			offsetStart(int)	: errors before this index are ignored.
			offsetEnd(int)		: the last offsetEnd compared symbols are ignored (0 keeps all).
			log(bool)			: print every compared symbol and its error flag.
		Returns:
			(ber, numErr): error ratio over the compared window and the number of errors in it.
		Raises:
			ZeroDivisionError: if the compared window is empty.
		'''
		ref = np.array(ref)

		#@@ Convert ref levels to rxOut indices (NRZ: -1,1 -> 0,1 // PAM4: -1,-1/3,1/3,1 -> 0,1,2,3)
		#@@ by slicing at the midpoints between adjacent levels. Unknown modulations compare ref as-is.
		modNum = {'nrz': 2, 'pam4': 4, 'pam8': 8}.get(self.mod_)
		if modNum is not None:
			amp = 2.0/(modNum-1)
			refIdx = np.full(ref.shape, modNum-1)
			for k in reversed(range(modNum-1)):
				refIdx = np.where(ref < -1+amp/2+amp*k, k, refIdx)
			ref = refIdx

		#@@ Drop the equalizer delay, then compare only the common length
		#@@ (e.g. viterbiOverlapPack output is shorter than its input).
		rxOut = np.array(rxOut)[delay:]
		length = min(len(ref), len(rxOut))
		ref = ref[:length]
		rxOut = rxOut[:length]
		bitErr = np.where(ref-rxOut == 0, 0, 1)
		if log:
			print ("")
			print ("ref rxOut be ")
			for k in range(len(ref)):
				print (ref[k], " ", rxOut[k], " ",bitErr[k])
			print (offsetStart, " ", offsetEnd)

		#@@ Explicit stop index: [offsetStart:-offsetEnd] would select nothing when offsetEnd == 0.
		bitErr = bitErr[offsetStart:max(len(bitErr)-offsetEnd, 0)]
		numErr = np.sum(bitErr)
		ber = float(numErr)/len(bitErr)
		return ber, numErr

		
	## Viterbi decoder bundle
	# @param blockSize	: Segment size. It must be larger than estTap size to obtain accurate result
	###def viterbiPack(self, rxIn, blockSize=100, v0=None, transit=None):
	###	# divide rxIn into segment
	###	blockDiv = np.ceil(len(rxIn)/blockSize)
	###	rxInSplit = np.array_split(rxIn, blockDiv)
	###	mlSeqPack = np.array([],dtype=int)
	###	# put segment to viterbi algorithm
	###	sys.stdout.write("viterbi processing...")
	###	for k in range(len(rxInSplit)):
	###		if (int(len(rxIn)/blockSize*0.1) is not 0):
	###			if (k%int(len(rxIn)/blockSize*0.1)==0) :
	###				sys.stdout.write("%0.1f%%.."%(float(k)/len(rxIn)*blockSize*100))
	###				sys.stdout.flush()
	###		if (k==0):
	###			(score,mlSeq,V) = self.viterbi(rxInSplit[k])
	###		else:
	###			(score,mlSeq,V) = self.viterbi(rxInSplit[k], v0=V)
	###			V = V.values()
	###			V=np.where(V==max(V),V,-1e6)
	###		mlSeqPack = np.concatenate((mlSeqPack,mlSeq))
	###	print("")

	###	return (score, mlSeqPack, V)

	## Viterbi overlap decoder bundle
	# @param blockSize	: Segment size. It must be larger than estTap size to obtain accurate result
	def viterbiOverlapPack(self, rxIn, blockSize=100, v0=None, transit=None):
		'''
		Description:
			It is wrapper of viterbiOverlap.
			Because viterbiOverlap input is segmented block, wrapper schedules input blocks to put in to viterbiOverlap.
			First, it divides rxIn sequences to piece-wise block.
			After dividing, each segmented block is put to viterbiOverlap in order.
			At the same time, output of each viterOverlap operation is append to 'mlSeqPack'
		Params:
			rxIn(float, list): viterbi input sequences. usually channel output.
			blockSize(int): viterbiOverlap input block size.
			v0(float, list): viterbi branch metric. Size must be hidden state size. If none, it equals to all one.
			transit(float, Mdim-list): transition probability matrix between hidden state.
		'''
		# divide rxIn into segment
		blockDiv = np.ceil(len(rxIn)/blockSize)
		rxInSplit = np.array_split(rxIn, blockDiv)
		mlSeqPack = np.array([],dtype=int)
		# put segment to viterbi algorithm
		sys.stdout.write("viterbi processing...")
		for k in range(len(rxInSplit)):
			if (int(len(rxIn)/blockSize*0.1) is not 0):
				if (k%int(len(rxIn)/blockSize*0.1)==0) :
					sys.stdout.write("%0.1f%%.."%(float(k)/len(rxIn)*blockSize*100))
					sys.stdout.flush()
			if (k==0):
				(score,mlSeq,V,seqOverlap) = self.viterbiOverlap(rxInSplit[k])
				#V = V.values()
			else:
				(score,mlSeq,V,seqOverlap) = self.viterbiOverlap(rxInSplit[k], v0=V, seqPrev=seqOverlap)
				#V = V.values()
				#V=np.where(V==max(V),V,-1e6)
			mlSeqPack = np.concatenate((mlSeqPack,mlSeq))
		print("")

		return (score, mlSeqPack, V)

		
	## Overlappable Viterbi decoder
	# @param self	: The object pointer
	# @param rxIn	: RX input sequences
	# @param transit: State transition matrix (optional)
	# @param v0		: Initial branch matrix of each state
	def viterbiOverlap(self, rxIn, v0=None, transit=None, seqPrev=None):
		'''
		Description:
			Viterbi algorithm core.
			It calculates branch metrics('V') whose size is the number of hidden state.
			Each branch store sequence which has the most largest branch metric.
			This viterbiOverlap decode only first half of the input block because last half block has not enough information.
		Params:
			rxIn(float, list): viterbi input sequence.
			v0(float, list): viterbi branch metric. Size must be hidden state size. If none, it equals to all one.
			transit(float, Mdim-list): transition probability matrix between hidden state.
			seqPrev(int, list): previous viterbi-decoded sequence. It is used for decoding current sequence.
		'''
		v0 = v0 if v0 is not None else np.full(self.numState_,1) 
		transit = transit if transit is not None else self.transit_
		lengthOverlap = int(len(seqPrev[0])) if seqPrev is not None else int(len(rxIn)/2)
		if seqPrev is not None:
			if len(seqPrev.keys()) is not self.numState_:
				print("Error: Number of state mismatch!!")
				sys.exit()
			

		#transit = np.where(transit==0,np.NaN,transit)
		numState = transit.shape[0] if transit is not None else self.numState_

		V = [{}]	#@@ V will store branch metrics.
		seq = {}	#@@ It stores maximum-likelihood sequence according to branch metric.

		#@@ initial branch metrics
		for s in range(numState):
			V[0][s] = v0[s]
	
		# print seq
		# run Viterbi algorithm (dynamic programming)

		for t in range(1, len(rxIn)+1): 
			V.append({})
			newSeq = {}
			#@@ find maximum scoring(branch metric) path for each state at each timestep
			for s in range(numState):
				(score, state) = max( (V[t-1][sPrev] + np.log(transit[sPrev][s] + 1e-15) - (self.expected_[s]-rxIn[t-1])**2
									, sPrev) for sPrev in range(numState) )
				V[t][s] = score
				if (t == 1):
					if (seqPrev is not None):
						newSeq[s] = seqPrev[state] + [s]
					else:
						newSeq[s] = [s]
				else:
					newSeq[s] = seq[state] + [s]
			seq = newSeq	 #@@ each seq has max scoring(branch metric) path at the state

	
		#@@ find maximum scoring path at final timestep	
		(score, state) = max( (V[t][s], s) for s in range(numState))
	
		viterbiOut = np.array(seq[state][:-lengthOverlap])
		seqOverlap = {}
		for s in range(numState):
			seqOverlap[s] = seq[s][-lengthOverlap:]
		if 0:
			print ("")
			print (seq[state])
			print (viterbiOut)

		#@@ Mapping (hidden state index) to (actual modulation symbol)
		if self.mod_ == 'nrz':
			#print ("1:",len(viterbiOut))
			#print ("2:",lengthOverlap)
			viterbiOut = np.where(viterbiOut%2==0, 1, 0)
		elif self.mod_ == 'pam4':
			viterbiOut = np.where(viterbiOut%4==0, 	int(3),
						np.where(viterbiOut%4==1, 	int(2),
						np.where(viterbiOut%4==2, 	int(1),
													int(0)
													)))
		elif self.mod_ == 'pam8':
			viterbiOut = np.where(viterbiOut%8==0, 	int(7),
						np.where(viterbiOut%8==1, 	int(6),
						np.where(viterbiOut%8==2, 	int(5),
						np.where(viterbiOut%8==3, 	int(4),
						np.where(viterbiOut%8==4, 	int(3),
						np.where(viterbiOut%8==5, 	int(2),
						np.where(viterbiOut%8==6, 	int(1),
													int(0)
													)))))))
			

		else:
			print("Invalid modulation")
	
		return (score, viterbiOut , V[-1], seqOverlap)
	## Viterbi decoder
	# @param self	: The object pointer
	# @param rxIn	: RX input sequences
	# @param transit: State transition matrix (optional)
	# @param v0		: Initial branch matrix of each state
###	def viterbi(self, rxIn, v0=None, transit=None):
###		v0 = v0 if v0 is not None else np.full(self.numState_,1) 
###		transit = transit if transit is not None else self.transit_
###
###		#transit = np.where(transit==0,np.NaN,transit)
###		numState = transit.shape[0] if transit is not None else self.numState_
###		rxInCopy = rxIn[:len(rxIn)-len(self.estTap_)+1].copy()
###
###		V = [{}]
###		seq = {}
###
###		# initial branch matrix
###		for s in range(numState):
###			V[0][s] = v0[s]
###			seq[s] = [s]
###	
###		#print seq
###		# run Viterbi algorithm (dynamic programming)
###
###		# for initial case
###		# not initial case
###		for t in range(1, len(rxIn)+1): 
###			V.append({})
###			newSeq = {}
###			# find maximum scoring path for each state at each timestep
###			for s in range(numState):
###				(score, state) = max( (V[t-1][sPrev] + np.log(transit[sPrev][s] + 1e-15) - (self.expected_[s]-rxIn[t-1])**2
###									, sPrev) for sPrev in range(numState) )
###				V[t][s] = score
###				if (t == 1):
###					newSeq[s] = [s]
###				else:
###					newSeq[s] = seq[state] + [s]
###			seq = newSeq
###
###	
###		# find maximum scoring path at final timestep	
###		(score, state) = max( (V[t][s], s) for s in range(numState))
###	
###		if self.mod_ == 'nrz':
###			mlSeq = (((np.array(seq[state]) & 1) + 1) % 2)*2 - 1
###			return (score, mlSeq, V[-1])
###		else:
###			print("Invalid modulation")
	

	
	def firDfe(self, rxIn, ffeTapNum=None, maxTapNum=None, dfeTapNum=None):
		'''
		Description:
			It is combination equalizer of FIR+DFE.
			firV2 firstly equalize input with given ffeTapNumm, maxTapNum.
			maxTapNum means the maximum cursor of FIR output SBR.
			The larger maxTapNum for given ffeTapNum, it concentrate more to cancel pre-cursor. The lower maxTapNum, more to cancel post-cursor.
		Params:
			rxIn(float, list): input sequence to be equalized.
			ffeTapNum(int) : The number of ffe tap
			maxTapNum(int) : The maximum cursor of ffe filter tap
			dfeTapNum(int) : The number of dfe tap
		'''
		sbr = self.estTap_
		fir, firOut, maxTapNum = self.firV2(rxIn, ffeTapNum=ffeTapNum, maxTapNum=maxTapNum)
		convFirSbr = np.convolve(fir,sbr)
		dfeOutReal, seq, delay = self.dfe(firOut, dfeTapNum=dfeTapNum, sbrOvrd=convFirSbr)
		if 0:
			plt.plot(rxIn,'-*', label='rxIn')
			plt.plot(firOut,'-v',label='firOut')
			plt.plot(sbr, '-h', label='sbr')
			plt.plot(convFirSbr,'-x',label='convFirSBr')
			plt.grid(True)
			plt.legend(loc='best')
			plt.show()
	

		return dfeOutReal, seq, delay
		




#	def firDfe(self, rxIn, ffeTapNum=None, maxTapNum=None, dfeTapNum=None):
#		sbr = self.estTap_
#		ffeTapNum = ffeTapNum if ffeTapNum is not None else len(self.estTap_)
#		maxTapNum = maxTapNum if maxTapNum is not None else list(sbr).index(max(sbr))+1
#		H_base = sbr[::-1]
#		H_base = np.concatenate((np.zeros(ffeTapNum-1), H_base, np.zeros(ffeTapNum-1)))
#		H = []
#		Y = np.zeros(len(sbr)+ffeTapNum-1)
#		Y[maxTapNum] = 1.0
#		#print Y
#		Y=np.matrix(Y)
#		#print (f'Y_des: {Y}')
#		for k in range(len(H_base)-ffeTapNum+1):
#			H.append(H_base[len(H_base)-ffeTapNum-k:len(H_base)-k])
#		H=np.matrix(H)
#		#print (f'H: {H}')
#		HT=H.getH()
#		W = (HT*H).getI()*HT*Y.getH()
#		Y_ = H*W
#		fir = np.array(W.getH())[0]
#		fir = fir/np.linalg.norm(fir) 
#		print (f"fir tap: {fir}\n")
#
#
#
#		firOut = np.convolve(rxIn,fir)
#		convFirSbr = np.convolve(fir,sbr)
#
#		dfeOutReal, seq, delay = self.dfe(firOut, dfeTapNum=dfeTapNum, sbrOvrd=convFirSbr)
#
#
#		if 0:
#			plt.plot(rxIn,'-*', label='rxIn')
#			plt.plot(firOut,'-v',label='firOut')
#			plt.plot(convFirSbr,'-x',label='convFirSBr')
#			plt.grid(True)
#			plt.legend(loc='best')
#			plt.show()
#	
#
#
#		return dfeOutReal, seq, delay

	def firV2(self, rxIn, ffeTapNum=None, maxTapNum=None):
		sbr = self.estTap_
		ffeTapNum = ffeTapNum if ffeTapNum is not None else len(self.estTap_)
		maxTapNum = maxTapNum if maxTapNum is not None else list(sbr).index(max(sbr))+1
		H_base = sbr[::-1]
		H_base = np.concatenate((np.zeros(ffeTapNum-1), H_base, np.zeros(ffeTapNum-1)))
		H = []
		Y = np.zeros(len(sbr)+ffeTapNum-1)
		Y[maxTapNum] = 1.0
		#print Y
		Y=np.matrix(Y)
		#print (f'Y_des: {Y}')
		for k in range(len(H_base)-ffeTapNum+1):
			H.append(H_base[len(H_base)-ffeTapNum-k:len(H_base)-k])
		H=np.matrix(H)
		#print (f'H: {H}')
		HT=H.getH()
		W = (HT*H).getI()*HT*Y.getH()
		Y_ = H*W
		fir = np.array(W.getH())[0]
		fir = fir/np.linalg.norm(fir) 
		if 0:
			print (f"fir tap: {fir}\n")
		firOut = np.convolve(rxIn,fir)
		oneAmp = np.sum(fir)*np.sum(sbr)
		#print (f"oneAmp: {oneAmp}\n")
		if 0:
			plt.plot(firOut,'-v',label='firOut')
			
		return fir, firOut, maxTapNum



	def fir(self, rxIn, ffeTapNum=None, maxTapNum=None):
		sbr = self.estTap_
		ffeTapNum = ffeTapNum if ffeTapNum is not None else len(self.estTap_)
		maxTapNum = maxTapNum if maxTapNum is not None else list(sbr).index(max(sbr))+1
		H_base = sbr[::-1]
		H_base = np.concatenate((np.zeros(ffeTapNum-1), H_base, np.zeros(ffeTapNum-1)))
		H = []
		Y = np.zeros(len(sbr)+ffeTapNum-1)
		Y[maxTapNum] = 1.0
		#print Y
		Y=np.matrix(Y)
		#print (f'Y_des: {Y}')
		for k in range(len(H_base)-ffeTapNum+1):
			H.append(H_base[len(H_base)-ffeTapNum-k:len(H_base)-k])
		H=np.matrix(H)
		#print (f'H: {H}')
		HT=H.getH()
		W = (HT*H).getI()*HT*Y.getH()
		Y_ = H*W
		fir = np.array(W.getH())[0]
		fir = fir/np.linalg.norm(fir) 
		#print (f"fir tap: {fir}\n")
		firOut = np.convolve(rxIn,fir)
		oneAmp = np.sum(fir)*np.sum(sbr)
		#print (f"oneAmp: {oneAmp}\n")
		if 0:
			plt.plot(firOut,'-v',label='firOut')
			
		if self.mod_ == 'nrz':
			#seq = np.where(firOut>0,1,-1)
			amp = 2.0/1
			seq = np.where (firOut < (-1+amp/2)*oneAmp, int(0),
												int(1)
												)
		elif self.mod_ == 'pam4':
			amp = 2.0/3
			seq = np.where(	firOut < (-1+amp/2+amp*0)*oneAmp, int(0),
				np.where(	firOut < (-1+amp/2+amp*1)*oneAmp, int(1),
				np.where(	firOut < (-1+amp/2+amp*2)*oneAmp, int(2), 
															 int(3)
															)))
		elif self.mod_ == 'pam8':
			amp = 2.0/7
			seq = np.where(	firOut < (-1+amp/2+amp*0)*oneAmp, int(0),
				np.where(	firOut < (-1+amp/2+amp*1)*oneAmp, int(1),
				np.where(	firOut < (-1+amp/2+amp*2)*oneAmp, int(2),
				np.where(	firOut < (-1+amp/2+amp*3)*oneAmp, int(3),
				np.where(	firOut < (-1+amp/2+amp*4)*oneAmp, int(4),
				np.where(	firOut < (-1+amp/2+amp*5)*oneAmp, int(5),
				np.where(	firOut < (-1+amp/2+amp*6)*oneAmp, int(6),
															 int(7)
															)))))))

		return seq, maxTapNum

	def dfe(self, rxIn, dfeTapNum=None, sbrOvrd=None):
		if sbrOvrd is not None:
			sbr = sbrOvrd
		else:
			sbr = self.estTap_
		maxidx = list(sbr).index(max(sbr))
		dfeTapNum = dfeTapNum if dfeTapNum is not None else len(sbr[maxidx:])-1
		if (dfeTapNum >= len(sbr[maxidx:])):
			print(f"dfeTapNum({dfeTapNum}) must be smaller than sbr[maxidx:]  ({len(sbr[maxidx:])})")
			sys.exit()

		#print (f'dfeTapNum: {dfeTapNum}')
		dfeOut = np.zeros(len(sbr[maxidx:]))
		dfeOut = list(dfeOut)
		#print dfeOut
		dfeOutReal = []
		oneAmp = np.max(sbr)
		oneConst= np.sum(sbr)
		if 0:
			print(f'oneAmp: {oneAmp}')
			print(f'oneConst: {oneConst}')
		for t in range(len(rxIn)):
			dfeTemp = rxIn[t]
			#for k in range(len(sbr[maxidx:])-1):
			for k in range(dfeTapNum):
				#print(k)
				#print(sbr[maxidx+k+1])
				#print(dfeOut[-k-1])
				dfeTemp -= sbr[maxidx+k+1]*dfeOut[-k-1]
				#print ("t: ",t, "sbr: ",sbr[maxidx+k+1], "dfeOut: ", dfeOut[-k-1])
				#print ("dfeOut: ",dfeOut)
			if (self.mod_=='nrz'):
				dfeOut.append(np.sign(dfeTemp))
			elif (self.mod_=='pam4'):
				amp = 2.0/3
				dfeOut.append(np.where(	dfeTemp < (-1+amp/2+amp*0)*oneAmp, (-1+amp*0),
							np.where(	dfeTemp < (-1+amp/2+amp*1)*oneAmp, (-1+amp*1),
							np.where(	dfeTemp < (-1+amp/2+amp*2)*oneAmp, (-1+amp*2),
																  1
																  ))))
			elif (self.mod_=='pam8'):
				amp = 2.0/7
				dfeOut.append(np.where(	dfeTemp	< (-1+amp/2+amp*0)*oneAmp, (-1+amp*0),
							np.where(	dfeTemp < (-1+amp/2+amp*1)*oneAmp, (-1+amp*1),
							np.where(	dfeTemp < (-1+amp/2+amp*2)*oneAmp, (-1+amp*2),
							np.where(	dfeTemp < (-1+amp/2+amp*3)*oneAmp, (-1+amp*3),
							np.where(	dfeTemp < (-1+amp/2+amp*4)*oneAmp, (-1+amp*4),
							np.where(	dfeTemp < (-1+amp/2+amp*5)*oneAmp, (-1+amp*5),
							np.where(	dfeTemp < (-1+amp/2+amp*6)*oneAmp, (-1+amp*6),
																	1
																	))))))))
			dfeOutReal.append(dfeTemp)
		dfeOut = dfeOut[dfeTapNum:]
		dfeOut = np.array(dfeOut)
		delay = maxidx
		if (self.mod_=='nrz'):
			amp = 2.0/1
			seq = np.where(	dfeOutReal < (-1+amp/2+amp*0)*oneAmp, int(0),
													int(1)
													)
		elif self.mod_ == 'pam4':
			amp = 2.0/3
			if 0:
				print (f'(-1+amp/2+amp*0)*oneAmp: {(-1+amp/2+amp*0)*oneAmp}')
				print (f'(-1+amp/2+amp*1)*oneAmp: {(-1+amp/2+amp*1)*oneAmp}')
				print (f'(-1+amp/2+amp*2)*oneAmp: {(-1+amp/2+amp*2)*oneAmp}')
			seq = np.where(	dfeOutReal < (-1+amp/2+amp*0)*oneAmp, int(0),
				np.where(	dfeOutReal < (-1+amp/2+amp*1)*oneAmp, int(1),
				np.where(	dfeOutReal < (-1+amp/2+amp*2)*oneAmp, int(2), 
													 int(3)
													)))
		elif self.mod_ == 'pam8':
			amp = 2.0/7
			seq = np.where(	dfeOutReal < (-1+amp/2+amp*0)*oneAmp, int(0),
				np.where(	dfeOutReal < (-1+amp/2+amp*1)*oneAmp, int(1),
				np.where(	dfeOutReal < (-1+amp/2+amp*2)*oneAmp, int(2),
				np.where(	dfeOutReal < (-1+amp/2+amp*3)*oneAmp, int(3),
				np.where(	dfeOutReal < (-1+amp/2+amp*4)*oneAmp, int(4),
				np.where(	dfeOutReal < (-1+amp/2+amp*5)*oneAmp, int(5),
				np.where(	dfeOutReal < (-1+amp/2+amp*6)*oneAmp, int(6),
													 int(7)
													)))))))
		if 0:
			if (self.mod_=='nrz'):
				plt.plot([oneAmp]*len(rxIn))
				plt.plot([-oneAmp]*len(rxIn))
			elif (self.mod_=='pam4'):
				plt.plot([(-1+amp*0)*oneAmp]*len(rxIn))
				plt.plot([(-1+amp*1)*oneAmp]*len(rxIn))
				plt.plot([(-1+amp*2)*oneAmp]*len(rxIn))
				plt.plot([oneAmp]*len(rxIn))
			plt.plot(rxIn,'-o',label='rxIn')
			plt.plot(dfeOutReal,'-h',label='dfeOutReal')
			plt.legend(loc='best')
			plt.show()
	
		return dfeOutReal, seq, delay

	def gaus(self, x, mu, sigma):
		y = (1 / np.sqrt(2 * np.pi * sigma**2)) * np.exp(-(x-mu)**2 / (2 * sigma**2))
		return y
	
	def fwd(self, rxIn, lengthFwd, snr, fwdPrev=None, transit=None, seqPrev=None):
		sigma = 10**(float(-snr) / 20)# for Forward-Backward Algorithm
		sbr = self.estTap_
		maxidx = list(sbr).index(max(sbr))
		
		fwdPrev= fwdPrev if fwdPrev is not None else np.full(self.numState_,1) 
		transit = transit if transit is not None else self.transit_

		#transit = np.where(transit==0,np.NaN,transit)
		numState = transit.shape[0] if transit is not None else self.numState_

		fwd = []
		matO = []
		seq = [np.nan]*(lengthFwd-1)
		prob = [np.nan]*(lengthFwd-1)

		fwdFirst = np.matrix([[1,1,1,1]])
		bwdFisrt = np.matrix([[1],[1],[1],[1]])
		matT = np.matrix(transit)


		sys.stdout.write("fwd processing...")
		for t in range(len(rxIn)-lengthFwd+1):
			rxSegment = rxIn[t:t+lengthFwd]
			if int(len(rxIn)*0.01) is not 0:
				if (t%int(len(rxIn)*0.01)==0) :
					sys.stdout.write("%0.1f%%.."%(float(t)/len(rxIn)*100))
					sys.stdout.flush()
			if (t==0):
				for k in range(lengthFwd):
					oList = list(map(self.gaus,[rxSegment[k]]*numState, self.expected_, [sigma]*numState))
					oDiag = np.diag(oList)
					matO.append(np.matrix(oDiag))
					#print ('')
					#print (rxSegment[k])
					#print (matO[k])
				for k in range(lengthFwd):
					fwd.append(np.matmul(matT,matO[k]))
				#print(fwd)
				#print(f'oList : {oList}')
				#sys.exit()

			else:
				### matO update
				matO[:-1] = matO[1:]
				oList = list(map(self.gaus,[rxSegment[-1]]*numState, self.expected_, [sigma]*numState))
				oDiag = np.diag(oList)
				matO[-1] = np.matrix(oDiag)

				### fwd update
				fwd[:-1] = fwd[1:]
				fwd[-1]  = np.matmul(matT,matO[lengthFwd-1])
				

			## Calculate Fwd & bwd
			fwdCal = np.matrix(np.ones(numState))
			for k, fwdUnit in enumerate(fwd):
				fwdCal = fwdCal*fwdUnit 

			fwdBwdVal = np.array(fwdCal.T) ## Fwd only
			fwdBwdVal = fwdBwdVal/np.sum(fwdBwdVal)
			fwdBwdValOne = 0
			for k, val in enumerate(fwdBwdVal):
				fwdBwdValOne += (-(k%2)+1)*val	## Calculate probability of One
			decBit = -(np.argmax(fwdBwdVal)%2)*2+1
			seq.append(decBit)
			prob.append(fwdBwdValOne)
			#print(fwdBwdVal)
			#print(decBit)
			#print(f"{t}")
		seq = seq 
		prob = prob
		#print(seq)
		#print(prob)

		return seq, prob

	


	def fwdBwd(self, rxIn, lengthFwdBwd, snr, fwdPrev=None, bwdPrev=None, transit=None, seqPrev=None):
		'''
		Description:
			fwdBwd inference a target data based on serial data.
			A target data is selected to the center index of the lengthFwdBwd.
			FwdBwd algorithm is running as follow.
				1. Compute forward probability based on sequence before target data.
				2. ex) T*O1 * T*O2 * T*O3 * T*O_target, where T is transition probability matrix and O# is observing proability
				3. Compute backward probability based on sequence after target data.
				4. ex) T*O5 * T*O6 * T*O7 *T*O8
				5. Point wise multiplication with fwd and bwd prob.
				6. ex) fwd_prob .* bwd_prob
		Params:
			rxIn(float, list): Input sequence to be eqaulized.
			lengthFwdBwd(int) : Size of fwdBwd calculation unit.
			snr(float): SNR is used to calculate noise distribution which is for probability calculation for observing probability.
			fwdPrev: Not used
			bwdPrev: Not used
			transit: Transition probability matrix from hidden state to hidden state
			seqPrev: Not used
		'''

		#@@ Calculate sigma from snr
		sigma = 10**(float(-snr) / 20)# for Forward-Backward Algorithm
		sbr = self.estTap_
		#maxidx = list(sbr).index(max(sbr)) 
		
		#fwdPrev = fwdPrev if fwdPrev is not None else np.full(self.numState_,1) 
		#bwdPrev = bwdPrev if bwdPrev is not None else np.full(self.numState_,1) 
		transit = transit if transit is not None else self.transit_
		#@@ LengthFwd is treated as half of lengthFwdBwd
		lengthFwd = int((lengthFwdBwd+1)/2)
		#@@ LengthBwd is rest of the lengthFwdBwd except lengthFwd
		lengthBwd = lengthFwdBwd - lengthFwd
		#print(f"fwd: {lengthFwd}, bwd: {lengthBwd}")
		#lengthOverlap = int(len(seqPrev[0])) if seqPrev is not None else int(len(rxIn)/2)

		#transit = np.where(transit==0,np.NaN,transit)
		numState = transit.shape[0] if transit is not None else self.numState_

		#@@ Initialize variable
		fwd = []
		bwd = []
		matO = []

		#@@ Set seq to garbage before first decoding bit to match with input data length
		seq = [np.nan]*(lengthFwd-1)
		#@@ Set probability also
		if self.mod_ == 'nrz':
			prob = [[np.nan]*2]*(lengthFwd-1)
		elif self.mod_ == 'pam4':
			prob = [[np.nan]*4]*(lengthFwd-1)
		elif self.mod_ == 'pam8':
			prob = [[np.nan]*8]*(lengthFwd-1)

		#fwdFirst = np.matrix([[1,1,1,1]])
		#bwdFisrt = np.matrix([[1],[1],[1],[1]])

		#@@ Transition matrix
		matT = np.matrix(transit)

		# initial fwd and bwd list
		#fwd = np.zeros(lengthFwd)
		#bwd = np.zeros(lengthBwd)
		#O = np.zeros(lengthFwdBwd)
		#print (transit)
		#print (self.expected_)

	
		# print seq
		# run Viterbi algorithm (dynamic programming)

		#@@ FwdBwd start for whole rxIn sequence
		sys.stdout.write("fwdbwd processing...")
		for t in range(len(rxIn)-lengthFwdBwd+1):
			#@@ Segmenting input sequence for given lengthFwdBwd
			rxSegment = rxIn[t:t+lengthFwdBwd]
			if int(len(rxIn)*0.01) is not 0:
				if (t%int(len(rxIn)*0.01)==0) :
					sys.stdout.write("%0.1f%%.."%(float(t)/len(rxIn)*100))
					sys.stdout.flush()
			#@@ At first time, calculate whole observe matrix and 
			if (t==0):
				for k in range(lengthFwdBwd):
					oList = list(map(self.gaus,[rxSegment[k]]*numState, self.expected_, [sigma]*numState))
					oDiag = np.diag(oList)
					matO.append(np.matrix(oDiag))
					#print (rxSegment[k])
					#print (matO[0])
				#@@ Calculate T*O for each data point
				for k in range(lengthFwd):
					fwd.append(np.matmul(matT,matO[k]))
				#print(fwd)
				for k in range(lengthBwd):
					bwd.append(matT*matO[k+lengthFwd])
				#print(f'oList : {oList}')
				#sys.exit()

			#@@ Except the first time, calculate just one O corresponding T*O and reuse others.
			else:
				### matO update
				matO[:-1] = matO[1:]
				oList = list(map(self.gaus,[rxSegment[-1]]*numState, self.expected_, [sigma]*numState))
				oDiag = np.diag(oList)
				matO[-1] = np.matrix(oDiag)
				#print (rxSegment[-1])

				### fwd update
				fwd[:-1] = fwd[1:]
				fwd[-1]  = np.matmul(matT,matO[lengthFwd-1])
				
				### bwd update
				bwd[:-1] = bwd[1:]
				bwd[-1] = np.matmul(matT,matO[-1])

			#@@ Calculate Fwd & bwd probability
			fwdCal = np.matrix(np.ones(numState))
			for k, fwdUnit in enumerate(fwd):
				fwdCal = fwdCal*fwdUnit 

			bwdCal = np.matrix(np.ones(numState)).T
			for k, bwdUnit in reversed(list(enumerate(bwd))):
				bwdCal = bwdUnit*bwdCal

			#@@ Point-wise multiplication with fwd and bwd prob.
			fwdBwdVal = np.array(fwdCal.T)*np.array(bwdCal)
			#print("")
			#print(f"fwdBwdVal: {fwdBwdVal}")
			#print(f"fwdBwdValSum: {np.sum(fwdBwdVal)}")
			#print("")

			#@@ Probability normalization
			if (np.sum(fwdBwdVal)!=0 ):
				fwdBwdVal = fwdBwdVal/np.sum(fwdBwdVal)	# Normalize to be probability

			#@@ Find argmax and map to actual modulation data
			#@@ ex) NRZ: index0 -> 1, index1 -> 0

			if self.mod_=='nrz':
				if 1:
					fwdBwdValProb = np.zeros(2)
					for k, val in enumerate(fwdBwdVal):
						fwdBwdValProb[k%2] += val
					prob.append(fwdBwdValProb)
					#print (prob)
					#sys.exit()
				else:
					fwdBwdValOne = 0
					for k, val in enumerate(fwdBwdVal):
						fwdBwdValOne += (-(k%2)+1)*val	## Calculate probability of One
					prob.append(fwdBwdValOne[0])
				decBit = np.where(np.argmax(fwdBwdVal)%2==0, int(1),
															int(0)
															)
			elif self.mod_=='pam4':
				fwdBwdValProb = np.zeros(4)
				for k, val in enumerate(fwdBwdVal):
					fwdBwdValProb[k%4] +=val
				prob.append(fwdBwdValProb)
				decBit = np.where(np.argmax(fwdBwdVal)%4==0, int(3),
						np.where(np.argmax(fwdBwdVal)%4==1, int(2),
						np.where(np.argmax(fwdBwdVal)%4==2, int(1),
					              							int(0)
															)))
			elif self.mod_=='pam8':
				fwdBwdValProb = np.zeros(8)
				for k, val in enumerate(fwdBwdVal):
					fwdBwdValProb[k%8] +=val
				prob.append(fwdBwdValProb)
				decBit = np.where(np.argmax(fwdBwdVal)%8==0, int(7),
						np.where(np.argmax(fwdBwdVal)%8==1, int(6),
						np.where(np.argmax(fwdBwdVal)%8==2, int(5),
						np.where(np.argmax(fwdBwdVal)%8==3, int(4),
						np.where(np.argmax(fwdBwdVal)%8==4, int(3),
						np.where(np.argmax(fwdBwdVal)%8==5, int(2),
						np.where(np.argmax(fwdBwdVal)%8==6, int(1),
															int(0)
															)))))))




			seq.append(decBit)
			#print(f'prob:{prob} @t:{t}')
			#print(fwdBwdVal)
			#print(decBit)
			#if (t==0):
			#	print(f"t: {t}, \nrxSegment: {rxSegment}, \nprob: {prob[-1]}")

		#@@ Set last edge to garbage to match to input length
		seq = seq + [np.nan]*(lengthBwd)
		if self.mod_ == 'nrz':
			prob = prob+[[np.nan]*2]*(lengthBwd)
		elif self.mod_ == 'pam4':
			prob = prob+[[np.nan]*4]*(lengthBwd)
		elif self.mod_ == 'pam8':
			prob = prob+[[np.nan]*8]*(lengthBwd)


		#print(seq)
		#print(prob)

		return seq, prob




				#print(bwd)
		#			for s in range(numState):
		#				fwd[k][s] = 0
		#				for sPrev in range(numState):
		#					if (k==0):
		#						fwd[k][s] += 1*transit[sPrev][s] * gaus(rxSegment[k], self.expected_[s])
		#					else:
		#						fwd[k][s] += fwd[k-1]*transit[sPrev][s] * gaus(rxSegment[k], self.expected_[s])
		#	else:
		#		fwd[1:] = fwd[:-1]
		#		for s in range(numState):
		#			fwd[0][s] = 0
		#			for sPrev in range(numState):
		#				fwd[0][s] += 1*transit[sPrev][s] * gaus(rxSegment[0], self.expected_[s])





		#for t in range(1, len(rxIn)+1): 
		#	V.append({})
		#	newSeq = {}
		#	# find maximum scoring path for each state at each timestep
		#	for s in range(numState):
		#		(score, state) = max( (V[t-1][sPrev] + np.log(transit[sPrev][s] + 1e-15) - (self.expected_[s]-rxIn[t-1])**2
		#							, sPrev) for sPrev in range(numState) )
		#		V[t][s] = score
		#		if (t == 1):
		#			if (seqPrev is not None):
		#				newSeq[s] = seqPrev[state] + [s]
		#			else:
		#				newSeq[s] = [s]
		#		else:
		#			newSeq[s] = seq[state] + [s]
		#	seq = newSeq

		## find maximum scoring path at final timestep	
		#(score, state) = max( (V[t][s], s) for s in range(numState))
	
		#if self.mod_ == 'nrz':
		#	#if seqPrev is None:
		#	#	viterbiOut = (((np.array(seq[state][:-lengthOverlap]) & 1) + 1) % 2)*2 - 1
		#	#else:
		#	#	viterbiOut = (((np.array(seq[state][:-lengthOverlap]) &1 ) + 1) % 2)*2 - 1
		#	viterbiOut = (((np.array(seq[state][:-lengthOverlap]) &1 ) + 1) % 2)*2 - 1
		#	seqOverlap = {}
		#	#print ("1:",len(viterbiOut))
		#	#print ("2:",lengthOverlap)
		#	for s in range(numState):
		#		seqOverlap[s] = seq[s][-lengthOverlap:]

		#	return (score, viterbiOut , V[-1], seqOverlap)
		#else:
		#	print("Invalid modulation")
	


if __name__ == '__main__':
	#@@ Quick DFE sanity check: random NRZ symbols through the channel model, then DFE + BER.
	dataSizeTestFinal=int(100)
	chSBR = [1.0,0.3,0.1]
	snrTest=10
	snrTestFinal=snrTest
	flagN = 0

	#############################################
	#### Test sequence for final evaluation #####
	#############################################
	#@@ Random NRZ symbols in {-1, +1}. The np.int alias was removed in NumPy 1.24, so use builtin int.
	chInTestFinal = 2 * np.random.randint(2, size=dataSizeTestFinal).astype(int) - 1
	ch3 = Channel(sbr=chSBR, snr=snrTestFinal)
	chOutTestFinal = ch3.run(chIn = chInTestFinal, flagN=flagN)

	test = eq(estTap=chSBR,snr=snrTest)
	#@@ Alternative FFE+DFE run:
	#dfeOutReal, seq, delay = test.firDfe(rxIn=chOutTestFinal, ffeTapNum=5, dfeTapNum=5, maxTapNum=3)
	dfeOutReal, seq, delay = test.dfe(rxIn=chOutTestFinal,dfeTapNum=None)
	(ber,be) = test.berChecker(chInTestFinal, seq, delay=delay, log=True)
	print(f"BER: {ber} ({be} errors)")
